# Copyright 2025 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import random
import warnings
from datetime import timedelta
from pathlib import Path
from typing import Any, Literal

import boto3
import botocore
from botocore.config import Config as botocoreConfig
from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout
from pydantic import BaseModel, field_validator, model_validator
from pydantic import ConfigDict as PydanticConfigDict

from openfold3.core.data.pipelines.preprocessing.template import (
    TemplatePreprocessorSettings,
)
from openfold3.core.data.tools.colabfold_msa_server import MsaComputationSettings
from openfold3.projects.of3_all_atom.config.dataset_configs import (
    InferenceDatasetConfigKwargs,
    TrainingDatasetPaths,
)
from openfold3.projects.of3_all_atom.project_entry import ModelUpdate

# Module-level logger, used by the download helper and the inference config.
logger = logging.getLogger(__name__)

# Run modes accepted by ExperimentSettings.mode and its subclasses.
ValidModeType = Literal["train", "predict", "eval", "test"]
# Fallback cache directory, used by InferenceExperimentConfig when neither
# `cache_path` nor the OPENFOLD_CACHE environment variable is set.
DEFAULT_CACHE_PATH = Path("~/.openfold3/").expanduser()
# Name of the file inside the cache directory that stores the path of the
# directory containing the model parameters.
CHECKPOINT_ROOT_FILENAME = "ckpt_root"
# File name of the default checkpoint, used both for the S3 object key and for
# the local copy.
CHECKPOINT_NAME = "of3_ft3_v1.pt"


def _maybe_download_parameters(target_path: Path) -> None:
    """Download the default OpenFold parameters to ``target_path`` if missing.

    Nothing happens when ``target_path`` already exists. Otherwise the size of
    the checkpoint is looked up in the public ``openfold`` S3 bucket (using
    unsigned requests) and the user is asked on stdin to confirm the download.

    The download is best-effort: any failure (network error, missing object,
    no interactive stdin, ...) is logged with its traceback and NOT re-raised,
    so ``target_path`` may still be missing when this function returns.

    Args:
        target_path: Local file path the checkpoint should be stored at.
    """
    # Cheap early exit before building any S3 client.
    if target_path.exists():
        return

    openfold_bucket = "openfold"
    checkpoint_path = f"openfold3_params/{CHECKPOINT_NAME}"

    s3 = boto3.client("s3", config=botocoreConfig(signature_version=botocore.UNSIGNED))

    try:
        # Get file size
        response = s3.head_object(Bucket=openfold_bucket, Key=checkpoint_path)
        size_gb = response["ContentLength"] / (1024**3)

        # Ask for confirmation with file size
        confirm = input(
            f"Download {checkpoint_path} ({size_gb:.2f} GB) "
            f"from s3://{openfold_bucket} to {target_path}? (yes/no): "
        )

        if confirm.strip().lower() in ("yes", "y"):
            # The destination directory may not exist yet (e.g. a parameter
            # directory read from the cache file).
            target_path.parent.mkdir(parents=True, exist_ok=True)
            logger.info("Downloading to %s...", target_path)
            # boto3 documents `Filename` as a str, so convert explicitly.
            s3.download_file(openfold_bucket, checkpoint_path, str(target_path))
            logger.info("Download complete!")
        else:
            logger.warning("Download cancelled")

    except Exception:
        # Deliberately best-effort: report through the logger (with traceback)
        # instead of printing to stdout, and let the caller continue.
        logger.exception(
            "Failed to download OpenFold parameters to %s", target_path
        )


class CheckpointConfig(BaseModel):
    """Settings for training checkpoint writing.

    NOTE(review): the field names match arguments of Lightning's
    ``ModelCheckpoint`` callback and are presumably forwarded to it unchanged;
    the consumer is not visible in this module -- confirm at the call site.
    """

    # Checkpoint frequency in epochs.
    every_n_epochs: int = 1
    # Presumably controls whether metric names are embedded in checkpoint file
    # names -- TODO confirm against the consumer.
    auto_insert_metric_name: bool = False
    # Presumably also keeps a copy of the most recent checkpoint -- TODO confirm.
    save_last: bool = True
    # Assuming ModelCheckpoint semantics, -1 keeps every checkpoint -- TODO confirm.
    save_top_k: int = -1


class WandbConfig(BaseModel):
    """Configuration for Weights and Biases experiment result logging.

    Every identifying field is optional and defaults to ``None``.
    NOTE(review): how unset fields are resolved (e.g. by wandb itself) is not
    visible in this module -- confirm in the logger setup code.
    """

    project: str | None = None
    experiment_name: str | None = None
    entity: str | None = None
    group: str | None = None
    # Run id. Must be set when `preemption_safe_resume` is enabled; see
    # TrainingExperimentConfig.check_preemption_safe.
    id: str | None = None
    # Presumably runs wandb in offline mode when True -- TODO confirm.
    offline: bool = False


class LoggingConfig(BaseModel):
    """Settings for training logging."""

    # Presumably toggles logging of the learning rate -- TODO confirm consumer.
    log_lr: bool = True
    # Presumably toggles logging of gradient statistics -- TODO confirm consumer.
    log_grads: bool = False
    # Optional log level name; None leaves it unset here.
    log_level: Literal["debug", "info", "warning", "error"] | None = None
    # Weights and Biases settings; None means no wandb configuration was given.
    # Required (with an `id`) when `preemption_safe_resume` is enabled.
    wandb_config: WandbConfig | None = None


class DataModuleArgs(BaseModel):
    """Settings for openfold3.core.data.framework.data_module.

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = PydanticConfigDict(extra="forbid")
    batch_size: int = 1
    # Seed for the data pipeline. When left as None, the experiment configs in
    # this module fill it in from the model seed (see their `synchronize_seeds`
    # validators).
    data_seed: int | None = None
    # Worker counts; presumably dataloader worker processes for training and
    # validation respectively -- TODO confirm in the data module.
    num_workers: int = 10
    num_workers_validation: int = 4
    # NOTE(review): the unit of `epoch_len` (samples vs. batches per epoch) is
    # not visible here -- confirm in the data module.
    epoch_len: int = 4


class PlTrainerArgs(BaseModel):
    """Arguments to configure pl.Trainer, including settings for number of devices.

    Extra keys are accepted and kept on the model (``extra="allow"``), so
    additional trainer arguments can be supplied without being declared here.
    """

    model_config = PydanticConfigDict(extra="allow")
    max_epochs: int = 1000  # pl_trainer default
    accelerator: str = "gpu"
    precision: int | str = "32-true"
    num_nodes: int = 1
    devices: int = 1  # number of GPUs per node
    profiler: str | None = None
    log_every_n_steps: int = 1
    enable_checkpointing: bool = True
    enable_model_summary: bool = False

    # Extra arguments that are not passed directly to pl.Trainer
    # Path to a DeepSpeed configuration file; None when not provided.
    deepspeed_config_path: Path | None = None
    # Timeout for the distributed process group; defaults to lightning_fabric's
    # `default_pg_timeout`.
    distributed_timeout: timedelta | None = default_pg_timeout
    # Presumably selects an MPI-based cluster environment plugin -- TODO confirm.
    mpi_plugin: bool = False


class OutputWritingSettings(BaseModel):
    """File formats to use for writing inference prediction results.

    Used by OF3OutputWriter in openfold3.core.runners.writer
    """

    # File format for predicted structures.
    structure_format: Literal["pdb", "cif"] = "cif"
    # File format for the full confidence output.
    full_confidence_output_format: Literal["json", "npz"] = "json"
    # Presumably also writes the input features / latent model outputs when
    # enabled -- TODO confirm in OF3OutputWriter.
    write_features: bool = False
    write_latent_outputs: bool = False


class ExperimentSettings(BaseModel):
    """General settings for all experiments.

    Validating a provided ``output_dir`` has a side effect: the directory is
    created when it does not exist yet.
    """

    # Run mode; one of the values listed in ValidModeType.
    mode: ValidModeType
    # Directory the experiment writes its outputs to.
    output_dir: Path = Path("./")
    # Optional separate directory for logs; None when not provided.
    log_dir: Path | None = None

    @field_validator("output_dir", mode="after")
    @classmethod
    def create_output_dir(cls, value: Path) -> Path:
        """Create ``output_dir`` (including parents) if it does not exist.

        Args:
            value: The already-validated output directory path.

        Returns:
            The unchanged path.
        """
        # Only call mkdir for a missing path: an existing path is returned
        # untouched, whatever it points to.
        if not value.exists():
            value.mkdir(parents=True, exist_ok=True)
        return value


class CheckpointLoadingSettings(BaseModel):
    """
    Provides more granular control over checkpoint loading.
    While the standard PL process restores the entire training state,
    these settings allow for selective loading of specific components.

    TrainingExperimentSettings enforces that `init_from_ema_weights`,
    `restore_lr_scheduler` and `restore_time_step` may only be enabled
    together with `manual_checkpoint_loading`, and that manual loading
    requires a `restart_checkpoint_path`.
    """

    # Master switch for the selective (manual) loading path.
    manual_checkpoint_loading: bool = False
    # Selective-loading options; each requires manual_checkpoint_loading=True.
    init_from_ema_weights: bool = False
    restore_lr_scheduler: bool = False
    restore_time_step: bool = False
    # Not part of the manual-loading consistency check. Presumably passed as the
    # `strict` flag when loading the state dict -- TODO confirm.
    strict_loading: bool = True


class TrainingExperimentSettings(ExperimentSettings):
    """General settings specific for training experiments."""

    mode: ValidModeType = "train"
    # Model seed; also used as the data seed when none is given (see
    # TrainingExperimentConfig.synchronize_seeds).
    seed: int = 42
    # Checkpoint to restart from; see `validate_checkpoint_path` for the
    # accepted values.
    restart_checkpoint_path: str | None = None
    # Requires a wandb config with a fixed id; see
    # TrainingExperimentConfig.check_preemption_safe.
    preemption_safe_resume: bool = False
    ckpt_load_settings: CheckpointLoadingSettings = CheckpointLoadingSettings()

    @field_validator("restart_checkpoint_path", mode="before")
    @classmethod
    def validate_checkpoint_path(cls, value: Any) -> str | None:
        """
        Validates the restart_checkpoint_path.

        The path can be one of the following:
        - None (if no checkpoint is provided).
        - A special string: "last", "hpc", "registry" accepted by PL.
        - A string representing a valid path to a file.
        - A string representing a valid path to a directory (for deepspeed checkpoints).

        ``os.PathLike`` objects (e.g. ``pathlib.Path``) are accepted as well
        and converted to their string form.

        Raises:
            ValueError: If the value is neither a string nor path-like, or if it
                is not an accepted keyword and does not exist on disk.
        """
        # PL accepted strings
        allowed_strings = ["last", "hpc", "registry"]

        if value is None:
            return None

        # The field is typed as `str`, so normalise path objects up front.
        if isinstance(value, os.PathLike):
            value = os.fspath(value)

        # Reject other types with a ValueError so pydantic reports a regular
        # validation error instead of a raw TypeError from Path(value).
        if not isinstance(value, str):
            raise ValueError(
                "restart_checkpoint_path must be a string or path-like object, "
                f"got {type(value).__name__}"
            )

        if value not in allowed_strings and not Path(value).exists():
            raise ValueError(
                f'"{value}" is not a valid file, directory, or accepted keyword '
                f"({', '.join(allowed_strings)})"
            )
        return value

    @model_validator(mode="after")
    def validate_ckpt_load_settings(self):
        """Check that the checkpoint loading settings are mutually consistent.

        Raises:
            ValueError: If a selective-loading option is enabled without
                `manual_checkpoint_loading`, or if manual loading is requested
                without a `restart_checkpoint_path`.
        """
        load_settings = self.ckpt_load_settings
        manual_settings_enabled = any(
            (
                load_settings.init_from_ema_weights,
                load_settings.restore_lr_scheduler,
                load_settings.restore_time_step,
            )
        )
        if not load_settings.manual_checkpoint_loading and manual_settings_enabled:
            raise ValueError(
                "If any manual checkpoint loading settings are enabled, "
                "manual_checkpoint_loading must be set to True."
            )
        if (
            self.restart_checkpoint_path is None
            and load_settings.manual_checkpoint_loading
        ):
            raise ValueError(
                "If manual_checkpoint_loading is set to True, "
                "restart_checkpoint_path must be provided."
            )

        return self


def generate_seeds(start_seed: int, num_seeds: int) -> list[int]:
    """Generate a reproducible list of random seeds.

    A private ``random.Random`` instance is used, so the process-wide state of
    the ``random`` module is left untouched. The values are the same as those
    produced by seeding the global generator with ``start_seed``.

    Args:
        start_seed: Seed used to initialise the generator.
        num_seeds: Number of seeds to draw; values <= 0 give an empty list.

    Returns:
        ``num_seeds`` integers in the inclusive range ``[0, 2**32 - 1]``.
    """
    rng = random.Random(start_seed)
    return [rng.randint(0, 2**32 - 1) for _ in range(num_seeds)]


class InferenceExperimentSettings(ExperimentSettings):
    """General settings specific for inference experiments"""

    mode: ValidModeType = "predict"
    # Either an explicit list of seeds, or a single starting seed from which
    # `num_seeds` seeds are generated.
    seeds: int | list[int] = [42]
    # Number of seeds to generate; only used when `seeds` is a single int.
    num_seeds: int | None = None
    use_msa_server: bool = False
    use_templates: bool = False
    skip_existing: bool = False

    @model_validator(mode="after")
    def generate_seeds(self):
        """Creates a list of seeds if a list of seeds is not provided.

        After validation ``seeds`` is always a non-empty list of ints.

        Raises:
            ValueError: If ``seeds`` is a single int and ``num_seeds`` is
                missing, if ``seeds`` is None, or if the resulting list of
                seeds is empty.
        """
        if self.seeds is None:
            raise ValueError("seeds must be provided (either int or list[int])")

        if not isinstance(self.seeds, list):
            if self.num_seeds is None:
                raise ValueError(
                    "Attempted to generate seeds using starting"
                    f" seed {self.seeds} but num_seeds was not provided."
                    " Please either provide `num_seeds` or a list of seeds."
                )
            # Inside a method body this name resolves to the module-level
            # helper, not to this validator.
            self.seeds = generate_seeds(self.seeds, self.num_seeds)

        # An empty list (explicit, or from num_seeds <= 0) would only fail
        # later with an IndexError when the first seed is read.
        if not self.seeds:
            raise ValueError(
                "At least one seed is required, but the list of seeds is empty."
            )

        return self


class ExperimentConfig(BaseModel):
    """Base set of arguments expected for all experiments.

    The training and inference configs in this module extend this class and
    give `experiment_settings` and `model_update` mode-specific defaults.
    """

    # General run settings (mode, output directory, log directory).
    experiment_settings: ExperimentSettings
    # Arguments used to configure the pl.Trainer.
    pl_trainer_args: PlTrainerArgs = PlTrainerArgs()
    # NOTE(review): ModelUpdate is defined in project_entry; presumably it
    # describes presets/overrides applied to the model config -- confirm there.
    model_update: ModelUpdate


class TrainingExperimentConfig(ExperimentConfig):
    """Top-level configuration of a training experiment.

    Unknown top-level keys are rejected. `dataset_paths` and `dataset_configs`
    are required; every other section has a default.
    """

    # pydantic model setting to prevent extra fields in main experiment config
    model_config = PydanticConfigDict(extra="forbid")
    # required arguments for training experiment
    dataset_paths: dict[str, TrainingDatasetPaths]
    dataset_configs: dict[str, Any]

    experiment_settings: TrainingExperimentSettings = TrainingExperimentSettings()
    logging_config: LoggingConfig = LoggingConfig()
    checkpoint_config: CheckpointConfig = CheckpointConfig()
    model_update: ModelUpdate = ModelUpdate(presets=["train"])
    data_module_args: DataModuleArgs = DataModuleArgs()

    @model_validator(mode="after")
    def synchronize_seeds(self):
        """Fall back to the model seed when no data seed was configured.

        An explicitly provided ``data_module_args.data_seed`` is left untouched.
        """
        dm_args = self.data_module_args
        if dm_args.data_seed is None:
            dm_args.data_seed = self.experiment_settings.seed
        return self

    @model_validator(mode="after")
    def check_preemption_safe(self):
        """Validate the prerequisites of ``preemption_safe_resume``.

        The setting is only supported for jobs that log to wandb with a fixed
        run id, so both a ``wandb_config`` and its ``id`` must be present when
        the setting is enabled. This validator only performs that check.

        As described for the setting itself, enabling it is meant to have the
        following effects:
        1. When restarted, the run resumes from the last locally saved
           checkpoint for the given wandb id.
        2. ckpt_load_settings are disabled if the run already exists and has
           existing checkpoints.
        3. restart_checkpoint_path is set to "last" if the run already exists
           and has existing checkpoints.

        Raises:
            ValueError: If the setting is enabled without a wandb config or
                without a wandb id.
        """
        if self.experiment_settings.preemption_safe_resume:
            wandb_settings = self.logging_config.wandb_config

            if wandb_settings is None:
                raise ValueError(
                    "The `preemption_safe_resume` setting currently only supports jobs "
                    "run with wandb. Please provide a wandb_config."
                )

            if wandb_settings.id is None:
                raise ValueError(
                    "The `preemption_safe_resume` setting requires wandb_config.id to "
                    "be set. This ensures that if a job is preempted, the new job resumes "
                    "from the same id."
                )

        return self


class InferenceExperimentConfig(ExperimentConfig):
    """Inference experiment config.

    Besides holding the settings, validation resolves the checkpoint to use:
    when `inference_ckpt_path` is not given, the parameters are looked up in
    (and, after confirmation, downloaded to) the cache directory.
    """

    # pydantic model setting to prevent extra fields in main experiment config
    model_config = PydanticConfigDict(extra="forbid")
    # Required inputs for performing inference
    inference_ckpt_path: Path | None = None
    # default location to look for parameters if no ckpt_path is given
    cache_path: Path | None = None

    experiment_settings: InferenceExperimentSettings = InferenceExperimentSettings()
    model_update: ModelUpdate = ModelUpdate(presets=["predict", "pae_enabled"])
    data_module_args: DataModuleArgs = DataModuleArgs()
    dataset_config_kwargs: InferenceDatasetConfigKwargs = InferenceDatasetConfigKwargs()
    output_writer_settings: OutputWritingSettings = OutputWritingSettings()
    msa_computation_settings: MsaComputationSettings = MsaComputationSettings()
    template_preprocessor_settings: TemplatePreprocessorSettings = (
        TemplatePreprocessorSettings(mode="predict")
    )

    @model_validator(mode="before")
    @classmethod
    def set_default_cache_path(cls, data):
        """Set default cache_path if not provided.

        The default is taken from the ``OPENFOLD_CACHE`` environment variable
        (with ``~`` expanded) and falls back to ``DEFAULT_CACHE_PATH``. The
        directory is created if needed.

        Args:
            data: Raw input to the model; only dict input is processed.

        Returns:
            The input, or a shallow copy with ``cache_path`` filled in. The
            caller's dict is never modified.
        """
        # Non-dict input has no `.get`; leave it to pydantic's own validation.
        if not isinstance(data, dict):
            return data

        if data.get("cache_path") is None:
            cache_path = Path(
                os.environ.get("OPENFOLD_CACHE") or DEFAULT_CACHE_PATH
            ).expanduser()
            cache_path.mkdir(parents=True, exist_ok=True)
            # Copy instead of mutating the caller's input in place.
            data = {**data, "cache_path": cache_path}
        return data

    @model_validator(mode="after")
    def _try_default_ckpt_path(self):
        """Resolve `inference_ckpt_path`, falling back to the cache directory.

        If a checkpoint path was provided it must exist. Otherwise the
        parameter directory is read from the ``ckpt_root`` file in
        `cache_path`; when that file is missing it is created and points to
        `cache_path` itself. A download of the default parameters is then
        offered if the checkpoint file is not present.

        Raises:
            ValueError: If a provided checkpoint path does not exist.
        """
        if self.inference_ckpt_path is not None:
            if not self.inference_ckpt_path.exists():
                raise ValueError(
                    f"Provided checkpoint path {self.inference_ckpt_path} "
                    "does not exist"
                )
            return self

        # Try using path set in cache
        path_to_ckpt = self.cache_path / CHECKPOINT_ROOT_FILENAME
        if path_to_ckpt.exists():
            with open(path_to_ckpt) as f:
                stored_dir = f.read().strip()
            if stored_dir:
                param_dir = Path(stored_dir)
            else:
                # An empty file would otherwise resolve to the current
                # working directory.
                logger.warning(
                    "%s is empty; falling back to %s", path_to_ckpt, self.cache_path
                )
                param_dir = self.cache_path
        # If not set, write the parameter directory to the default location
        else:
            param_dir = self.cache_path
            # An explicitly provided cache_path may not exist yet.
            param_dir.mkdir(parents=True, exist_ok=True)
            logger.info("Storing path to OpenFold parameters in %s", path_to_ckpt)
            with open(path_to_ckpt, "w") as f:
                f.write(str(param_dir))

        self.inference_ckpt_path = param_dir / CHECKPOINT_NAME
        _maybe_download_parameters(self.inference_ckpt_path)

        # The download is best-effort and may be declined; make the missing
        # checkpoint visible without changing the non-raising behaviour.
        if not self.inference_ckpt_path.exists():
            logger.warning(
                "OpenFold parameters are not available at %s",
                self.inference_ckpt_path,
            )
        return self

    @model_validator(mode="after")
    def synchronize_seeds(self):
        """
        Ensures data_seed in DataModuleArgs is set. If it isn't, it will
        default to the first model seed in the provided list.

        Raises:
            ValueError: If a data seed is needed but the list of model seeds
                is empty.
        """
        if self.data_module_args.data_seed is None:
            model_seeds = self.experiment_settings.seeds
            # Report an empty list as a validation error, not an IndexError.
            if not model_seeds:
                raise ValueError(
                    "experiment_settings.seeds must contain at least one seed "
                    "when data_module_args.data_seed is not set."
                )
            self.data_module_args.data_seed = model_seeds[0]

        return self

    @model_validator(mode="after")
    def copy_ccd_file_path(self):
        """Copies ccd_file_path dataset_config_kwargs>template_preprocessor_settings.

        Nothing is copied when ``dataset_config_kwargs.ccd_file_path`` is None.
        A warning is emitted when an existing value in
        ``template_preprocessor_settings`` gets overwritten.
        """
        if self.dataset_config_kwargs.ccd_file_path is not None:
            if self.template_preprocessor_settings.ccd_file_path is not None:
                warnings.warn(
                    "Overwriting ccd_file_path in template_preprocessor_settings with "
                    "dataset_config_kwargs.ccd_file_path. We recommend specifying "
                    "ccd_file_path only in dataset_config_kwargs.",
                    stacklevel=2,
                )
            self.template_preprocessor_settings.ccd_file_path = (
                self.dataset_config_kwargs.ccd_file_path
            )

        return self